import math
from functools import partial
from typing import Dict, Any, List

import torch
from torch import Tensor
from torch import nn
from torch.distributions import Normal, Independent

from ding.torch_utils import to_device, fold_batch, unfold_batch, unsqueeze_repeat
from ding.utils import POLICY_REGISTRY
from ding.policy import SACPolicy
from ding.rl_utils import generalized_lambda_returns
from ding.policy.common_utils import default_preprocess_learn

from .utils import q_evaluation


@POLICY_REGISTRY.register('mbsac')
class MBSACPolicy(SACPolicy):
    """
    Overview:
        Model based SAC with value expansion (arXiv: 1803.00101)
        and value gradient (arXiv: 1510.09142) w.r.t lambda-return.

        https://arxiv.org/pdf/1803.00101.pdf
        https://arxiv.org/pdf/1510.09142.pdf

    Config:
        == ====================   ========    =============  ==================================
        ID Symbol                 Type        Default Value  Description
        == ====================   ========    =============  ==================================
        1  ``learn.lambda_``      float       0.8            | Lambda for TD-lambda return.
        2  ``learn.grad_clip``    float       100.0          | Max norm of gradients.
        3  | ``learn.sample``     bool        True           | Whether to sample states or
           | ``_state``                                      | transitions from env buffer.
        == ====================   ========    =============  ==================================

    .. note::
        For other configs, please refer to ding.policy.sac.SACPolicy.
    """

    config = dict(
        learn=dict(
            # (float) Lambda for TD-lambda return.
            lambda_=0.8,
            # (float) Max norm of gradients.
            grad_clip=100.0,
            # (bool) Whether to sample states or transitions from environment buffer.
            sample_state=True,
        )
    )

    def _init_learn(self) -> None:
        """
        Overview:
            Extend ``SACPolicy._init_learn``: freeze the target network, read the model-based \
            hyper-parameters and build the actor / critic closures consumed by the world-model rollout.
        Raises:
            - NotImplementedError: If ``learn.auto_alpha`` is enabled (not supported yet).
        """
        super()._init_learn()
        # The target network only supplies bootstrap values; it is synced from the learn model in
        # ``_forward_learn`` and must never receive gradients from the value-gradient loss.
        self._target_model.requires_grad_(False)

        self._lambda = self._cfg.learn.lambda_
        self._grad_clip = self._cfg.learn.grad_clip
        self._sample_state = self._cfg.learn.sample_state
        self._auto_alpha = self._cfg.learn.auto_alpha
        # TODO: auto alpha. An explicit raise (instead of ``assert``) keeps this guard active under ``python -O``.
        if self._auto_alpha:
            raise NotImplementedError("auto_alpha is not implemented for MBSACPolicy")

        # Precomputed once instead of allocating ``torch.tensor(2.)`` on every actor call.
        log_2 = math.log(2.0)

        # NOTE: ``torch.distributions.TanhTransform`` led to NaN here, so the tanh squashing and its
        # log-det-Jacobian correction are applied manually below.
        def actor_fn(obs: Tensor):
            """Sample a tanh-squashed action; return it with its entropy bonus ``-alpha * log pi(a|s)``."""
            (mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit']
            dist = Independent(Normal(mu, sigma), 1)
            # Reparameterized sample, so the value-gradient loss can backpropagate into the actor.
            pred = dist.rsample()
            action = torch.tanh(pred)

            # log pi(a) = log p(u) - sum_i log(1 - tanh(u_i)^2), using the numerically stable identity
            # log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)).
            tanh_correction = 2 * (pred + torch.nn.functional.softplus(-2. * pred) - log_2).sum(-1)
            log_prob = dist.log_prob(pred) + tanh_correction
            # alpha is detached: the temperature is fixed because auto_alpha is unsupported.
            return action, -self._alpha.detach() * log_prob

        self._actor_fn = actor_fn

        def critic_fn(obss: Tensor, actions: Tensor, model: nn.Module):
            """Evaluate ``model``'s critic on ``(obss, actions)`` and return its ``q_value`` output."""
            eval_data = {'obs': obss, 'action': actions}
            q_values = model.forward(eval_data, mode='compute_critic')['q_value']
            return q_values

        self._critic_fn = critic_fn
        self._forward_learn_cnt = 0

    def _forward_learn(self, data: dict, world_model, envstep) -> Dict[str, Any]:
        """
        Overview:
            One training step. Roll out ``world_model`` from replayed data with the current actor, build \
            TD(lambda) targets from the target critic, then update the critic (value expansion) and the \
            actor (value gradient through the imagined trajectory).
        Arguments:
            - data (:obj:`dict`): Batch sampled from the environment buffer.
            - world_model: Model providing ``rollout(obs, actor_fn, envstep)``, which returns \
                ``(obss, actions, rewards, aug_rewards, dones)`` stacked along a leading time dimension.
            - envstep: Current environment step, forwarded to ``world_model.rollout``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Learning rates, alpha, mean target Q value, gradient norms \
                and detached losses, for logging.
        """
        # preprocess data
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        if self._cuda:
            data = to_device(data, self._device)

        # Give 1-D actions an explicit trailing action dimension.
        if len(data['action'].shape) == 1:
            data['action'] = data['action'].unsqueeze(1)

        self._learn_model.train()
        self._target_model.train()

        # TODO: use treetensor
        # rollout length is determined by world_model.rollout_length_scheduler
        if self._sample_state:
            # Imagine from the sampled states; data['reward'], ... are not used.
            obss, actions, rewards, aug_rewards, dones = \
                world_model.rollout(data['obs'], self._actor_fn, envstep)
        else:
            # Imagine from next_obs, then prepend the real transition (obs, action, reward, done).
            # The real action comes from the buffer, so it gets no entropy bonus.
            obss, actions, rewards, aug_rewards, dones = \
                world_model.rollout(data['next_obs'], self._actor_fn, envstep)
            obss = torch.cat([data['obs'].unsqueeze(0), obss])
            actions = torch.cat([data['action'].unsqueeze(0), actions])
            rewards = torch.cat([data['reward'].unsqueeze(0), rewards])
            aug_rewards = torch.cat([torch.zeros_like(data['reward']).unsqueeze(0), aug_rewards])
            dones = torch.cat([data['done'].unsqueeze(0), dones])

        # Prepend a zero for the first state: afterwards ``dones[1:]`` is passed alongside ``rewards`` to
        # the lambda-return and ``dones[:-1]`` masks states reached after a termination.
        dones = torch.cat([torch.zeros_like(data['done']).unsqueeze(0), dones])

        # (T+1, B)
        target_q_values = q_evaluation(obss, actions, partial(self._critic_fn, model=self._target_model))
        # Soft value target: min over twin critics (if enabled) plus the entropy bonus.
        if self._twin_critic:
            target_q_values = torch.min(target_q_values[0], target_q_values[1]) + aug_rewards
        else:
            target_q_values = target_q_values + aug_rewards

        # (T, B)
        lambda_return = generalized_lambda_returns(target_q_values, rewards, self._gamma, self._lambda, dones[1:])

        # (T, B)
        # If S_t terminates, we should not consider loss from t+1,...
        weight = (1 - dones[:-1].detach()).cumprod(dim=0)

        # (T+1, B)
        # Inputs are detached so the critic loss cannot flow back into the actor or the world model.
        q_values = q_evaluation(obss.detach(), actions.detach(), partial(self._critic_fn, model=self._learn_model))
        if self._twin_critic:
            critic_loss = 0.5 * torch.square(q_values[0][:-1] - lambda_return.detach()) \
                        + 0.5 * torch.square(q_values[1][:-1] - lambda_return.detach())
        else:
            critic_loss = 0.5 * torch.square(q_values[:-1] - lambda_return.detach())

        # value expansion loss
        critic_loss = (critic_loss * weight).mean()

        # value gradient loss: maximize the lambda-return by backpropagating through the imagined rollout
        policy_loss = -(lambda_return * weight).mean()

        # TODO: add 'alpha_loss' once auto alpha is supported.
        loss_dict = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
        }

        norm_dict = self._update(loss_dict)

        # =============
        # after update
        # =============
        self._forward_learn_cnt += 1
        # target update
        self._target_model.update(self._learn_model.state_dict())

        return {
            'cur_lr_q': self._optimizer_q.defaults['lr'],
            'cur_lr_p': self._optimizer_policy.defaults['lr'],
            'alpha': self._alpha.item(),
            'target_q_value': target_q_values.detach().mean().item(),
            **norm_dict,
            # Detached so the logged losses do not keep the autograd graph alive.
            **{k: v.detach() for k, v in loss_dict.items()},
        }

    def _update(self, loss_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Overview:
            Take one optimizer step for the critic, then one for the actor, clipping each gradient norm \
            to ``self._grad_clip``.
        Arguments:
            - loss_dict (:obj:`Dict[str, Tensor]`): Must contain ``critic_loss`` and ``policy_loss``.
        Returns:
            - norm_dict (:obj:`Dict[str, Tensor]`): Total gradient norms of actor and critic as returned \
                by ``clip_grad_norm_``.
        """
        # update critic
        self._optimizer_q.zero_grad()
        loss_dict['critic_loss'].backward()
        critic_norm = nn.utils.clip_grad_norm_(self._model.critic.parameters(), self._grad_clip)
        self._optimizer_q.step()
        # update policy
        self._optimizer_policy.zero_grad()
        loss_dict['policy_loss'].backward()
        policy_norm = nn.utils.clip_grad_norm_(self._model.actor.parameters(), self._grad_clip)
        self._optimizer_policy.step()
        # TODO: update temperature (alpha optimizer step) once auto alpha is supported.
        return {'policy_norm': policy_norm, 'critic_norm': critic_norm}

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return the names of the variables to be monitored during learning.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        alpha_loss = ['alpha_loss'] if self._auto_alpha else []
        return [
            'policy_loss',
            'critic_loss',
            'policy_norm',
            'critic_norm',
            'cur_lr_q',
            'cur_lr_p',
            'alpha',
            'target_q_value',
        ] + alpha_loss


@POLICY_REGISTRY.register('stevesac')
class STEVESACPolicy(SACPolicy):
    r"""
    Overview:
        Model based SAC with stochastic value expansion (arXiv 1807.01675).
        This implementation also uses value gradient w.r.t the same STEVE target.

        https://arxiv.org/pdf/1807.01675.pdf

    Config:
        == ====================    ========    =============  =====================================
        ID Symbol                  Type        Default Value  Description
        == ====================    ========    =============  =====================================
        1  ``learn.grad_clip``     float       100.0          | Max norm of gradients.
        2  ``learn.ensemble_size`` int         1              | The number of ensemble world models.
        == ====================    ========    =============  =====================================

    .. note::
        For other configs, please refer to ding.policy.sac.SACPolicy.
    """

    config = dict(
        learn=dict(
            # (float) Max norm of gradients.
            grad_clip=100.0,
            # (int) The number of ensemble world models.
            ensemble_size=1,
        )
    )

    def _init_learn(self) -> None:
        """
        Overview:
            Extend ``SACPolicy._init_learn``: freeze the target network, read the STEVE hyper-parameters \
            and build the actor / critic closures consumed by the world-model rollout.
        Raises:
            - NotImplementedError: If ``learn.auto_alpha`` is enabled (not supported yet).
        """
        super()._init_learn()
        # The target network only supplies bootstrap values; it is synced from the learn model in
        # ``_forward_learn`` and must never receive gradients from the value-gradient loss.
        self._target_model.requires_grad_(False)

        self._grad_clip = self._cfg.learn.grad_clip
        self._ensemble_size = self._cfg.learn.ensemble_size
        self._auto_alpha = self._cfg.learn.auto_alpha
        # TODO: auto alpha. An explicit raise (instead of ``assert``) keeps this guard active under ``python -O``.
        if self._auto_alpha:
            raise NotImplementedError("auto_alpha is not implemented for STEVESACPolicy")

        # Precomputed once instead of allocating ``torch.tensor(2.)`` on every actor call.
        log_2 = math.log(2.0)

        def actor_fn(obs: Tensor):
            """Sample a tanh-squashed action; return it with its entropy bonus ``-alpha * log pi(a|s)``."""
            # fold_batch / unfold_batch presumably merge and restore the leading (e.g. ensemble x batch)
            # dims around the actor call, keeping one trailing feature dim.
            obs, dim = fold_batch(obs, 1)
            (mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit']
            dist = Independent(Normal(mu, sigma), 1)
            # Reparameterized sample, so the value-gradient loss can backpropagate into the actor.
            pred = dist.rsample()
            action = torch.tanh(pred)

            # log pi(a) = log p(u) - sum_i log(1 - tanh(u_i)^2), using the numerically stable identity
            # log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)).
            tanh_correction = 2 * (pred + torch.nn.functional.softplus(-2. * pred) - log_2).sum(-1)
            log_prob = dist.log_prob(pred) + tanh_correction
            # alpha is detached: the temperature is fixed because auto_alpha is unsupported.
            aug_reward = -self._alpha.detach() * log_prob

            return unfold_batch(action, dim), unfold_batch(aug_reward, dim)

        self._actor_fn = actor_fn

        def critic_fn(obss: Tensor, actions: Tensor, model: nn.Module):
            """Evaluate ``model``'s critic on ``(obss, actions)`` and return its ``q_value`` output."""
            eval_data = {'obs': obss, 'action': actions}
            q_values = model.forward(eval_data, mode='compute_critic')['q_value']
            return q_values

        self._critic_fn = critic_fn
        self._forward_learn_cnt = 0

    def _forward_learn(self, data: dict, world_model, envstep) -> Dict[str, Any]:
        """
        Overview:
            One training step. Roll out ``world_model`` (keeping the ensemble dimension) from ``next_obs``, \
            form a candidate target for every horizon, average each over the ensemble and combine horizons \
            with normalized inverse-variance weights (the STEVE target). The critic regresses \
            ``Q(obs, action)`` onto the detached target; the actor maximizes the target (value gradient).
        Arguments:
            - data (:obj:`dict`): Batch sampled from the environment buffer.
            - world_model: Model providing ``rollout(obs, actor_fn, envstep, keep_ensemble=True)``, which \
                returns ``(obss, actions, rewards, aug_rewards, dones)`` stacked along a leading time dim.
            - envstep: Current environment step, forwarded to ``world_model.rollout``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Learning rates, alpha, mean target Q value, gradient norms \
                and detached losses, for logging.
        """
        # preprocess data
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        if self._cuda:
            data = to_device(data, self._device)

        # Give 1-D actions an explicit trailing action dimension.
        if len(data['action'].shape) == 1:
            data['action'] = data['action'].unsqueeze(1)

        # [B, D] -> [E, B, D]: one copy of the real transition per ensemble member.
        data['next_obs'] = unsqueeze_repeat(data['next_obs'], self._ensemble_size)
        data['reward'] = unsqueeze_repeat(data['reward'], self._ensemble_size)
        data['done'] = unsqueeze_repeat(data['done'], self._ensemble_size)

        self._learn_model.train()
        self._target_model.train()

        obss, actions, rewards, aug_rewards, dones = \
            world_model.rollout(data['next_obs'], self._actor_fn, envstep, keep_ensemble=True)
        # Prepend the real reward / done of the sampled transition obs -> next_obs.
        rewards = torch.cat([data['reward'].unsqueeze(0), rewards])
        dones = torch.cat([data['done'].unsqueeze(0), dones])

        # (T, E, B)
        target_q_values = q_evaluation(obss, actions, partial(self._critic_fn, model=self._target_model))
        # Soft value target: min over twin critics (if enabled) plus the entropy bonus.
        if self._twin_critic:
            target_q_values = torch.min(target_q_values[0], target_q_values[1]) + aug_rewards
        else:
            target_q_values = target_q_values + aug_rewards

        # (T+1, E, B): discounts[t] = prod_{k<t} gamma * (1 - done_k), with discounts[0] = 1.
        discounts = ((1 - dones) * self._gamma).cumprod(dim=0)
        discounts = torch.cat([torch.ones_like(discounts)[:1], discounts])
        # (T, E, B): horizon-t candidate = discounted reward sum up to t + discounted bootstrap value.
        cum_rewards = (rewards * discounts[:-1]).cumsum(dim=0)
        discounted_q_values = target_q_values * discounts[1:]
        steve_return = cum_rewards + discounted_q_values
        # (T, B): ensemble mean per horizon.
        steve_return_mean = steve_return.mean(1)
        # Inverse ensemble variance per horizon, normalized over horizons; treated as constants.
        with torch.no_grad():
            steve_return_inv_var = 1 / (1e-8 + steve_return.var(1, unbiased=False))
            steve_return_weight = steve_return_inv_var / (1e-8 + steve_return_inv_var.sum(dim=0))
        # (B, )
        steve_return = (steve_return_mean * steve_return_weight).sum(0)

        # The critic is trained on the real (obs, action) pair from the buffer.
        eval_data = {'obs': data['obs'], 'action': data['action']}
        q_values = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
        if self._twin_critic:
            critic_loss = 0.5 * torch.square(q_values[0] - steve_return.detach()) \
                        + 0.5 * torch.square(q_values[1] - steve_return.detach())
        else:
            critic_loss = 0.5 * torch.square(q_values - steve_return.detach())

        critic_loss = critic_loss.mean()

        # value gradient loss: maximize the STEVE target by backpropagating through the imagined rollout
        policy_loss = -steve_return.mean()

        # TODO: add 'alpha_loss' once auto alpha is supported.
        loss_dict = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
        }

        norm_dict = self._update(loss_dict)

        # =============
        # after update
        # =============
        self._forward_learn_cnt += 1
        # target update
        self._target_model.update(self._learn_model.state_dict())

        return {
            'cur_lr_q': self._optimizer_q.defaults['lr'],
            'cur_lr_p': self._optimizer_policy.defaults['lr'],
            'alpha': self._alpha.item(),
            'target_q_value': target_q_values.detach().mean().item(),
            **norm_dict,
            # Detached so the logged losses do not keep the autograd graph alive.
            **{k: v.detach() for k, v in loss_dict.items()},
        }

    def _update(self, loss_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Overview:
            Take one optimizer step for the critic, then one for the actor, clipping each gradient norm \
            to ``self._grad_clip``.
        Arguments:
            - loss_dict (:obj:`Dict[str, Tensor]`): Must contain ``critic_loss`` and ``policy_loss``.
        Returns:
            - norm_dict (:obj:`Dict[str, Tensor]`): Total gradient norms of actor and critic as returned \
                by ``clip_grad_norm_``.
        """
        # update critic
        self._optimizer_q.zero_grad()
        loss_dict['critic_loss'].backward()
        critic_norm = nn.utils.clip_grad_norm_(self._model.critic.parameters(), self._grad_clip)
        self._optimizer_q.step()
        # update policy
        self._optimizer_policy.zero_grad()
        loss_dict['policy_loss'].backward()
        policy_norm = nn.utils.clip_grad_norm_(self._model.actor.parameters(), self._grad_clip)
        self._optimizer_policy.step()
        # TODO: update temperature (alpha optimizer step) once auto alpha is supported.
        return {'policy_norm': policy_norm, 'critic_norm': critic_norm}

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return the names of the variables to be monitored during learning.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        alpha_loss = ['alpha_loss'] if self._auto_alpha else []
        return [
            'policy_loss',
            'critic_loss',
            'policy_norm',
            'critic_norm',
            'cur_lr_q',
            'cur_lr_p',
            'alpha',
            'target_q_value',
        ] + alpha_loss
